import itertools
import logging

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats
import sklearn.decomposition
from mpl_toolkits.mplot3d import Axes3D

class BehaviorPlot(object):
    """
    Plot behavior data in a 3D behavior space: time x goal state x mobility.
    """
    # Matplotlib color codes handed out cyclically by _getColors.
    colors = ['r', 'b', 'y', 'g', 'c', 'm', 'k']

    def __init__(self):
        """Create a plotter. Instances currently hold no state."""

    @staticmethod
    def _zScore(data):
        """
        Z-score every column of ``data``.

        If any z-score is NaN, a warning is logged and the columns are passed
        through ``np.nan_to_num`` (NaN -> 0, infinities -> large finite values).

        :param data: pandas DataFrame of numeric columns.
        :returns: DataFrame of per-column z-scores.
        """
        zdata = data.apply(scipy.stats.mstats.zscore)
        if zdata.isnull().any().any():
            logging.warning('Filtering NaN in standardized data!')
            zdata = zdata.apply(np.nan_to_num)

        return zdata

    @staticmethod
    def _getColors(ids):
        """
        Pair each id with its position and a color from ``BehaviorPlot.colors``.

        :param ids: iterable of identifiers.
        :returns: iterator of ``(index, id, color)`` tuples; colors repeat
            cyclically once the palette is exhausted.
        """
        return zip(itertools.count(), ids, itertools.cycle(BehaviorPlot.colors))

    @staticmethod
    def _importData(data):
        """
        Convert a raw behavior table into plotting coordinates.

        Each row is reduced to one goal-state code by taking the highest
        priority active state: Lick=1, NosePoke=2, In zone=3, and 4 for
        everything else (out of zone). Mobility is the 'Velocity' column.

        :param data: DataFrame with 'Lick', 'NosePoke', 'In zone' and
            'Velocity' columns. The three state columns are interpreted as
            booleans via ``astype('bool')``.
        :returns: DataFrame with columns ['GoalVector', 'Mobility'] sharing
            ``data``'s index. If the conversion fails, an error is logged
            and ``data`` is returned unchanged.
        """
        try:
            states = data[['Lick', 'NosePoke', 'In zone']].astype('bool')
            # np.select takes the first condition that holds, so the list
            # order encodes the priority Lick > NosePoke > In zone > out.
            goal = np.select(
                [states['Lick'].to_numpy(),
                 states['NosePoke'].to_numpy(),
                 states['In zone'].to_numpy()],
                [1, 2, 3],
                default=4,
            )
            mobility = data['Velocity'].to_numpy()
            # Alternative mobility source: data['Distance moved'].
            return pd.DataFrame({'GoalVector': goal, 'Mobility': mobility},
                                index=data.index)
        except Exception as e:
            logging.error('Importing data failed! %s', e)
        return data

    def plot(self, data, title, plottype='scatter', zlim=(0, 50)):
        """
        Generate a 3D scatter or path plot of behavior over time.

        ``data`` is first converted by ``_importData``. The x axis is the
        resulting DataFrame's index (time), y is its first column (goal
        state), and z is its second column (mobility).

        :param data: raw behavior DataFrame accepted by ``_importData``.
        :param title: figure title.
        :param plottype: ``'scatter'`` for individual points or ``'path'``
            for a connected trajectory.
        :param zlim: ``(min, max)`` limits of the z (mobility) axis.
        :raises ValueError: if ``plottype`` is neither 'scatter' nor 'path'.
        """
        if plottype not in ('scatter', 'path'):
            raise ValueError(
                "plottype must be 'scatter' or 'path', got %r" % (plottype,))

        data = self._importData(data)
        f = plt.figure()
        ax = f.add_subplot(111, projection='3d')
        X = data.index.values
        Y = data.iloc[:, 0].values
        Z = data.iloc[:, 1].values
        if plottype == 'scatter':
            ax.scatter(X, Y, Z, c='b', alpha=0.5)
        else:
            # One polyline traces the same trajectory as drawing each
            # consecutive pair separately, but with a single artist.
            ax.plot(X, Y, Z, color='b')

        # Fix the tick locations before attaching labels so every label is
        # bound to its goal-state code (see _importData).
        ax.set_yticks([0, 1, 2, 3, 4])
        ax.set_yticklabels(['', 'Lick', 'NosePoke', 'In Zone', 'Out Zone'])

        ax.set_xlabel('Time')
        ax.set_ylabel(data.columns.tolist()[0])
        ax.set_zlabel(data.columns.tolist()[1])

        ax.set_zlim(zlim)

        ax.set_title(title)
        f.show()

